from os.path import join, splitext

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from .base import Trainer


class CoTeachingTrainer(Trainer):
    """Trainer implementing Co-teaching with two peer networks.

    On every mini-batch each network ranks the samples by its own per-sample
    cross-entropy and keeps the ``1 - forget_rate`` fraction with the smallest
    loss. Each network is then updated on the subset selected by its *peer*:
    ``model`` is trained on the samples picked by ``model2`` and vice versa.

    Per-epoch schedules built in ``__init__`` (one entry per epoch):

    * ``rate_schedule`` -- forget rate. It is ramped linearly from 0 to
      ``forget_rate ** exponet`` over the first ``num_gradual`` epochs and
      equals ``forget_rate`` afterwards.
    * ``alpha_plan`` -- learning rate. It is ``config["train"]["learning_rate"]``
      until ``epoch_decay_start``, then decays linearly towards zero.
    * ``beta1_plan`` -- first optimizer beta. It is ``mom1`` before
      ``epoch_decay_start`` and ``mom2`` from then on.

    NOTE(review): ``criterion2`` is stored but not used by this class; the
    training losses in ``run`` are computed with ``F.cross_entropy`` directly.
    ``self.epoch``, ``self.train_loader``, ``self.prepare_data``,
    ``self.accuracy``, ``self.evaluate``, ``self.get_lr``, ``self.global_iter``
    and ``self.save_every_epoch`` are presumably provided by ``Trainer``
    (not visible here).
    """

    def __init__(
        self,
        config,
        model,
        model2,
        logger,
        train_set,
        test_set,
        criterion,
        optimizer,
        criterion2,
        optimizer2,
        num_gradual=10,
        forget_rate=0.2,
        exponet=1,
        mom1=0.9,
        mom2=0.1,
        epoch_decay_start=40,
        scheduler=None,
        val_set=None,
    ):
        """Set up both models/optimizers and precompute the per-epoch schedules.

        Args:
            config: Nested config mapping. This class reads
                ``config["train"]["learning_rate"]``,
                ``config["general"]["save_model_dir"]`` and
                ``config["general"]["logger"]["frequency"]``.
            model, model2: The two peer networks.
            logger: Logger; ``info(message, dict)`` is called on it.
            train_set, test_set, val_set: Datasets forwarded to ``Trainer``.
            criterion, optimizer: Loss/optimizer of the first model
                (forwarded to ``Trainer``).
            criterion2, optimizer2: Loss/optimizer of the second model.
            num_gradual: Number of epochs over which the forget rate is ramped up.
            forget_rate: Fraction of large-loss samples dropped per batch.
            exponet: Exponent applied to ``forget_rate`` at the end of the ramp.
            mom1, mom2: beta1 used before / from ``epoch_decay_start``.
            epoch_decay_start: First epoch of the linear learning-rate decay.
            scheduler: Forwarded to ``Trainer``.
        """
        super().__init__(
            config,
            model,
            logger,
            train_set,
            test_set,
            criterion,
            optimizer,
            scheduler,
            val_set,
        )
        self.model2 = model2
        self.optimizer2 = optimizer2
        self.criterion2 = criterion2
        self.num_gradual = num_gradual
        self.forget_rate = forget_rate
        self.exponet = exponet
        self.mom1 = mom1
        self.mom2 = mom2
        self.epoch_decay_start = epoch_decay_start

        # Forget-rate schedule: linear ramp, then constant forget_rate.
        self.rate_schedule = np.ones(self.epoch) * self.forget_rate
        ramp = np.linspace(0, self.forget_rate**self.exponet, self.num_gradual)
        # Truncate so num_gradual > epoch does not raise a shape-mismatch error.
        ramp = ramp[: len(self.rate_schedule)]
        self.rate_schedule[: len(ramp)] = ramp

        # Learning-rate / beta1 schedule with linear decay from epoch_decay_start.
        base_lr = self.config["train"]["learning_rate"]
        self.alpha_plan = [base_lr] * self.epoch
        self.beta1_plan = [self.mom1] * self.epoch
        for i in range(self.epoch_decay_start, self.epoch):
            self.alpha_plan[i] = (
                float(self.epoch - i)
                / (self.epoch - self.epoch_decay_start)
                * base_lr
            )
            self.beta1_plan[i] = self.mom2

    def save_model(self, file_name, second_model=False):
        """Save a model's ``state_dict`` into ``config["general"]["save_model_dir"]``.

        Args:
            file_name: Target file name. ``.pth`` is appended unless it already
                ends with ``.pth`` or ``.pt``.
            second_model: If true, save ``model2`` and insert ``_2`` before the
                extension; otherwise save ``model``.
        """
        # The old check (`".pth" not in f or ".pt" not in f`) turned "x.pt" into
        # "x.pt.pth"; an explicit suffix test accepts both extensions.
        if not file_name.endswith((".pth", ".pt")):
            file_name += ".pth"
        model = self.model
        if second_model:
            stem, ext = splitext(file_name)
            file_name = f"{stem}_2{ext}"
            model = self.model2
        torch.save(
            model.state_dict(),
            join(self.config["general"]["save_model_dir"], file_name),
        )
        print(f"Model saved as {file_name}")

    def save_best_model(self, second_model=False):
        """Save the first (or, with ``second_model``, the second) model as ``best``."""
        self.save_model("best", second_model=bool(second_model))

    def save_last_model(self, second_model=False):
        """Save the first (or, with ``second_model``, the second) model as ``last``."""
        self.save_model("last", second_model=bool(second_model))

    def adjust_learning_rate(self, epoch, second_model=False):
        """Apply ``alpha_plan[epoch]`` / ``beta1_plan[epoch]`` to an optimizer.

        Sets ``lr`` and ``betas`` on every param group. beta2 is overwritten
        with 0.999. This assumes an Adam-style optimizer whose param groups
        carry a ``betas`` entry.

        Args:
            epoch: Index into the precomputed schedules.
            second_model: Adjust ``optimizer2`` instead of ``optimizer``. If no
                second optimizer exists, fall back to the first one.
        """
        optimizer = self.optimizer
        if second_model:
            optimizer2 = getattr(self, "optimizer2", None)
            if optimizer2 is None:
                print("There is no second model. Still testing the first model.")
            else:
                optimizer = optimizer2

        for param_group in optimizer.param_groups:
            param_group["lr"] = self.alpha_plan[epoch]
            param_group["betas"] = (self.beta1_plan[epoch], 0.999)  # Only change beta1

    def _co_teaching_step(self, inputs, labels, remember_rate):
        """Run one Co-teaching update of both models on a single mini-batch.

        Args:
            inputs, labels: Batch as returned by ``self.prepare_data``.
            remember_rate: Fraction of small-loss samples each model keeps.

        Returns:
            ``(loss_1, loss_2, prec_1, prec_2)``. The losses are Python floats:
            the mean cross-entropy each model was updated with. The precisions
            are the top-1 values returned by ``self.accuracy``.
        """
        logits_1 = self.model(inputs)
        prec_1, _ = self.accuracy(logits_1, labels, topk=(1, 5))
        per_sample_1 = F.cross_entropy(logits_1, labels, reduction="none")
        # argsort returns indices on the same device as its input, so no .cuda().
        order_1 = torch.argsort(per_sample_1.detach())

        logits_2 = self.model2(inputs)
        prec_2, _ = self.accuracy(logits_2, labels, topk=(1, 5))
        per_sample_2 = F.cross_entropy(logits_2, labels, reduction="none")
        order_2 = torch.argsort(per_sample_2.detach())

        # Keep at least one sample: a mean over an empty selection is NaN.
        num_remember = max(1, int(remember_rate * len(per_sample_1)))
        keep_1 = order_1[:num_remember]
        keep_2 = order_2[:num_remember]

        # Cross update: each model learns from the samples its peer trusts.
        loss_1 = F.cross_entropy(logits_1[keep_2], labels[keep_2])
        loss_2 = F.cross_entropy(logits_2[keep_1], labels[keep_1])

        self.optimizer.zero_grad()
        loss_1.backward()
        self.optimizer.step()
        self.optimizer2.zero_grad()
        loss_2.backward()
        self.optimizer2.step()

        return loss_1.item(), loss_2.item(), prec_1, prec_2

    def run(self):
        """Train both models for ``self.epoch`` epochs with Co-teaching.

        After each epoch:

        * both models are evaluated on the validation set (if one was given)
          and on the test set;
        * each model's best test checkpoint is saved as ``best`` / ``best_2``;
        * a periodic checkpoint is saved every ``self.save_every_epoch`` epochs;
        * the latest state is saved as ``last`` / ``last_2``.
        """
        print("==> Start training..")
        best_acc_1, best_acc_2 = 0.0, 0.0
        log_frequency = self.config["general"]["logger"]["frequency"]
        for cur_epoch in range(self.epoch):
            self.model.train()
            self.model2.train()
            self.adjust_learning_rate(cur_epoch)
            self.adjust_learning_rate(cur_epoch, second_model=True)
            remember_rate = 1 - self.rate_schedule[cur_epoch]

            num_batches = 0
            # Accumulate Python floats (not tensors) so no autograd graph is
            # kept alive across the epoch.
            epoch_loss_1, epoch_loss_2 = 0.0, 0.0
            train_correct_1, train_correct_2 = 0.0, 0.0
            with tqdm(self.train_loader, unit="batch") as tepoch:
                tepoch.set_description(f"Epoch {cur_epoch}")
                for data in tepoch:
                    inputs, labels, _attributes, _idx = self.prepare_data(data)
                    loss_1, loss_2, prec_1, prec_2 = self._co_teaching_step(
                        inputs, labels, remember_rate
                    )
                    num_batches += 1
                    train_correct_1 += prec_1
                    train_correct_2 += prec_2
                    acc_1 = float(train_correct_1) / num_batches
                    acc_2 = float(train_correct_2) / num_batches

                    tepoch.set_postfix(
                        loss_1=loss_1,
                        accuracy_1=acc_1,
                        loss_2=loss_2,
                        accuracy_2=acc_2,
                        lr=self.get_lr(),
                    )
                    self.global_iter += 1
                    epoch_loss_1 += loss_1
                    epoch_loss_2 += loss_2

                    if self.global_iter % log_frequency == 0:
                        self.logger.info(
                            f"[{cur_epoch}]/[{self.epoch}], Global Iter: {self.global_iter}, Loss 1: {loss_1:.4f}, Loss 2: {loss_2:.4f}, Acc 1: {acc_1:.4f}, Acc 2: {acc_2:.4f}, lr: {self.get_lr():.6f}",
                            {
                                "cur_epoch": cur_epoch,
                                "iter": self.global_iter,
                                "loss 1": loss_1,
                                "loss 2": loss_2,
                                "Accuracy 1": acc_1,
                                "Accuracy 2": acc_2,
                                "lr": self.get_lr(),
                            },
                        )

            if self.val_set:
                _ = self.evaluate(val=True)
                _ = self.evaluate(val=True, second_model=True)
            test_acc_1 = self.evaluate(val=False)
            test_acc_2 = self.evaluate(val=False, second_model=True)

            if test_acc_1 > best_acc_1:
                best_acc_1 = test_acc_1
                self.save_best_model()

            if test_acc_2 > best_acc_2:
                best_acc_2 = test_acc_2
                self.save_best_model(second_model=True)

            # Per-batch means; max(..., 1) guards against an empty loader.
            denom = max(num_batches, 1)
            epoch_loss_1 /= denom
            epoch_loss_2 /= denom
            train_acc_1 = float(train_correct_1) / denom
            train_acc_2 = float(train_correct_2) / denom
            summary = f"Epoch: {cur_epoch}, Loss 1: {epoch_loss_1:.6f}, Loss 2: {epoch_loss_2:.6f}, Train Acc 1: {train_acc_1:.4f}, Train Acc 2: {train_acc_2:.4f}, Test Acc 1: {test_acc_1:.4f}, Test Acc 2: {test_acc_2:.4f}, Best Test Acc 1: {best_acc_1:.4f}, Best Test Acc 2: {best_acc_2:.4f}"
            print(summary)
            self.logger.info(
                summary,
                {
                    "test_epoch": cur_epoch,
                    "loss 1": epoch_loss_1,
                    "loss 2": epoch_loss_2,
                    "Train Acc 1": train_acc_1,
                    "Train Acc 2": train_acc_2,
                    "Test Acc 1": test_acc_1,
                    "Test Acc 2": test_acc_2,
                    "Best Test Acc 1": best_acc_1,
                    "Best Test Acc 2": best_acc_2,
                },
            )

            if cur_epoch % self.save_every_epoch == 0:
                self.save_model(f"{cur_epoch}")
                self.save_model(f"{cur_epoch}", second_model=True)
            self.save_last_model()
            self.save_last_model(second_model=True)
